import os, sys
from typing import Optional, Tuple, Union, List
import torch
def wrap_w_funcs(original_func):
    from .weight_grad_store import WeightGradStore
    def wrapped_func(total_input, grad_output, weight):
        from megatron.training import get_args
        if os.getenv("ENABLE_ZERO_BUBBLE", "0") == "1":
            WeightGradStore.put((total_input, grad_output, weight), original_func)
        else:
            original_func(total_input, grad_output, weight)
    return wrapped_func

def wrap_w_funcs_gemm(original_func):
    def wrapped_func(
        A: torch.Tensor,
        B: torch.Tensor,
        dtype: torch.dtype,
        workspace: torch.Tensor,
        layout: str = "TN",
        **kwargs
    ) -> Tuple[Union[torch.Tensor, None], ...]:

        import functools
        from musa_patch.zbb_light.weight_grad_store import WeightGradStore
        from megatron.training import get_args
        if layout == "NT" and os.getenv("ENABLE_ZERO_BUBBLE", "0") == "1":
            WeightGradStore.put(
                                (A, B, dtype, workspace),
                                functools.partial(
                                    original_func,
                                    layout="NT",
                                    **kwargs)
                                )
            return (None, None, None)
        else:
            return original_func(A, B, dtype, workspace, layout=layout, **kwargs)


    return wrapped_func

def wrap_w_general_gemm(original_func):
    def wrapped_func(
        A: torch.Tensor,
        B: torch.Tensor,
        workspace: torch.Tensor,
        layout: str = "TN",
        **kwargs
    ) -> Tuple[Union[torch.Tensor, None], ...]:

        import functools
        from musa_patch.zbb_light.weight_grad_store import WeightGradStore
        from megatron.training import get_args
        if layout == "NT" and os.getenv("ENABLE_ZERO_BUBBLE", "0") == "1":
            from transformer_engine.pytorch.utils import clear_tensor_data
            WeightGradStore.put(
                                (A, B, workspace),
                                functools.partial(
                                    original_func,
                                    layout="NT",
                                    **kwargs),
                                clear_tensor_data
                                )
            return (None, None, None, None)
        else:
            return original_func(A, B, workspace, layout=layout, **kwargs)


    return wrapped_func

def wrap_w_general_grouped_gemm(original_func):
    def wrapped_func(
        A: List[torch.Tensor],
        B: List[torch.Tensor],
        out: List[torch.Tensor],
        out_dtype: torch.dtype,
        workspaces: List[torch.Tensor],
        layout: str = "TN",
        use_bias: bool = False,
        **kwargs
    ) -> Tuple[Union[torch.Tensor, None], ...]:

        import functools
        from musa_patch.zbb_light.weight_grad_store import WeightGradStore
        from megatron.training import get_args
        if layout == "NT" and os.getenv("ENABLE_ZERO_BUBBLE", "0") == "1":
            from transformer_engine.pytorch.utils import clear_tensor_data
            WeightGradStore.put(
                                (A, B, out, out_dtype, workspaces),
                                functools.partial(
                                    original_func,
                                    layout="NT",
                                    use_bias = False,
                                    **kwargs),
                                clear_tensor_data
                                )
            assert use_bias== False, "Zero-bubble doesn't support the case where bias is used."
            return (None, [None] * len(A), None)
        else:
            return original_func(A, B, out, out_dtype, workspaces, use_bias = use_bias, layout=layout, **kwargs)


    return wrapped_func

    
def patch_megatron():
    assert all([not x.startswith('megatron') for x in sys.modules.keys()]), 'Please patch zbpp before importing any megatron modules.'
    import fused_weight_gradient_mlp_cuda
    assert hasattr(fused_weight_gradient_mlp_cuda, 'wgrad_gemm_accum_fp32')
    assert hasattr(fused_weight_gradient_mlp_cuda, 'wgrad_gemm_accum_fp16')
    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32 = wrap_w_funcs(fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32)
    fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16 = wrap_w_funcs(fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16)

    import transformer_engine.pytorch.cpp_extensions
    transformer_engine.pytorch.cpp_extensions.gemm = wrap_w_funcs_gemm(transformer_engine.pytorch.cpp_extensions.gemm)
    transformer_engine.pytorch.cpp_extensions.general_gemm = wrap_w_general_gemm(transformer_engine.pytorch.cpp_extensions.general_gemm)
    transformer_engine.pytorch.cpp_extensions.general_grouped_gemm = wrap_w_general_grouped_gemm(transformer_engine.pytorch.cpp_extensions.general_grouped_gemm)
    
    import megatron.core.pipeline_parallel
    from .zb_schedule import get_zero_bubble_forward_backward_func
    assert hasattr(megatron.core.pipeline_parallel.schedules, 'get_forward_backward_func')
    assert hasattr(megatron.core.pipeline_parallel, 'get_forward_backward_func')
    megatron.core.pipeline_parallel.schedules.get_forward_backward_func_origin = megatron.core.pipeline_parallel.schedules.get_forward_backward_func
    megatron.core.pipeline_parallel.get_forward_backward_func_origin = megatron.core.pipeline_parallel.get_forward_backward_func
    megatron.core.pipeline_parallel.schedules.get_forward_backward_func = get_zero_bubble_forward_backward_func
    megatron.core.pipeline_parallel.get_forward_backward_func = get_zero_bubble_forward_backward_func
